# 导入必要的库
import numpy as np  # 用于多维数组和矩阵运算
import faiss  # 用于高效的相似性搜索和稠密向量聚类
from util import createXY  # 数据生成工具，用于创建训练数据的特征和标签
from sklearn.model_selection import train_test_split  # 用于拆分数据集为训练集和测试集
from sklearn.neighbors import KNeighborsClassifier  # Sklearn中的K近邻分类器
import argparse  # 用于解析命令行参数
import logging  # 用于记录日志
from tqdm import tqdm  # 用于在循环中显示进度条
from FaissKNeighbors import FaissKNeighbors  # 自定义的FaissKNeighbors类

# 配置日志，打印正在运行的函数和信息
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# ######### 步骤 1: 获取命令行参数 #########
def get_args():
    """
    解析命令行参数，用于设置训练模式、特征提取方法和使用的库
    """
    parser = argparse.ArgumentParser(description='使用CPU或GPU训练模型。')
    parser.add_argument('-m', '--mode', type=str, required=True, choices=['cpu', 'gpu'], help='选择训练模式：CPU或GPU。')
    parser.add_argument('-f', '--feature', type=str, required=True, choices=['flat', 'vgg'], help='选择特征提取方法：flat或vgg。')
    parser.add_argument('-l', '--library', type=str, required=True, choices=['sklearn', 'faiss'], help='选择使用的库：sklearn或faiss。')
    args = parser.parse_args()
    return args

args = get_args()  # 获取命令行参数

# 初始化FAISS所需资源
res = faiss.StandardGpuResources() if args.mode == 'gpu' else None
logging.info(f"选择模式是 {args.mode.upper()}")
logging.info(f"选择特征提取方法是 {args.feature.upper()}")
logging.info(f"选择使用的库是 {args.library.upper()}")

# ######### 步骤 2: 加载和预处理数据 #########
# 使用指定的特征提取方法创建数据集的特征和标签
X, y = createXY(train_folder="../data/train", dest_folder=".", method=args.feature)
X = np.array(X)  # 转换为numpy数组
y = np.array(y)
logging.info("数据加载和预处理完成。")

# 将数据划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=2023)
logging.info("数据集划分为训练集和测试集。")

# ######### 步骤 3: 初始化变量和K近邻模型 #########
best_k = -1  # 初始化最佳k值
best_accuracy = 0.0  # 初始化最高准确率
k_values = range(1, 6)  # 定义测试的k值范围

# 根据提供的库选择K近邻算法实现
KNNClass = FaissKNeighbors if args.library == 'faiss' else KNeighborsClassifier
logging.info(f"使用的库为: {args.library.upper()}")

# ######### 步骤 4: 遍历k值，训练并评估模型 #########
for k in tqdm(k_values, desc='寻找最佳k值'):
    # 根据库选择相应的K近邻模型
    knn = KNNClass(k=k, res=res) if args.library == 'faiss' else KNNClass(n_neighbors=k)
    knn.fit(X_train, y_train)  # 训练模型
    accuracy = knn.score(X_test, y_test)  # 计算模型在测试集上的准确率
    
    # 更新最佳k值和准确率
    if accuracy > best_accuracy:
        best_k = k
        best_accuracy = accuracy

# ######### 输出结果 #########
logging.info(f'最佳k值: {best_k}, 最高准确率: {best_accuracy}')

